import torch
import torch.nn as nn


def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` (ShuffleNet channel shuffle).

    The channel axis is viewed as a ``groups x (C // groups)`` grid and that
    grid is transposed, so output channel ``i * groups + g`` is taken from
    input channel ``g * (C // groups) + i``.

    Args:
        x: 4-D tensor ``(N, C, H, W)``; ``C`` is expected to be divisible by
            ``groups``.
        groups: number of channel groups to interleave.

    Returns:
        A contiguous tensor with the same shape as ``x``.
    """
    n, c, h, w = x.size()
    grouped = x.view(n, groups, c // groups, h, w)
    # Swapping the group axis with the within-group axis interleaves the
    # groups; ``contiguous`` materializes that order so the final ``view``
    # is legal on the transposed memory layout.
    interleaved = torch.transpose(grouped, 1, 2).contiguous()
    return interleaved.view(n, c, h, w)


def depthwise_conv(input_c, output_c, kernel_s, stride=1, padding=0, bias=False):
    """Build an ``nn.Conv2d`` with one group per input channel.

    Args:
        input_c: number of input channels; also used as ``groups``.
        output_c: number of output channels. With ``groups == input_c``,
            ``nn.Conv2d`` requires this to be a multiple of ``input_c``.
        kernel_s: convolution kernel size.
        stride: convolution stride (default 1).
        padding: zero padding added to each spatial side (default 0).
        bias: whether to learn an additive bias (default ``False``).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(
        kernel_size=kernel_s,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=input_c,
    )
    return nn.Conv2d(input_c, output_c, **conv_kwargs)


class InvertedResidual(nn.Module):
    """ShuffleNetV2-style unit: channel split (or downsample), then shuffle.

    * ``stride == 1``: the input is split into two halves along dim 1. The
      first half is passed through unchanged and the second half goes through
      ``branch2``. ``input_c`` must equal ``output_c``.
    * ``stride == 2``: the full input goes through both ``branch1`` and
      ``branch2``. Each branch contains a stride-2 depthwise conv and
      produces ``output_c // 2`` channels.

    In both cases the two results are concatenated along dim 1 and passed
    through ``channel_shuffle(out, 2)``.

    Args:
        input_c: number of input channels.
        output_c: number of output channels; must be even.
        stride: stride of the depthwise convolutions; 1 or 2.

    Raises:
        ValueError: if ``stride`` is not 1 or 2, if ``output_c`` is odd, or
            if ``stride == 1`` while ``input_c != output_c``.
    """

    def __init__(self, input_c, output_c, stride):
        super().__init__()
        # Validate with explicit exceptions rather than ``assert``: asserts
        # are stripped under ``python -O``. Without this check, e.g. stride=3
        # would silently build an identity ``branch1`` whose output cannot be
        # concatenated with ``branch2``'s.
        if stride not in (1, 2):
            raise ValueError(f"stride must be 1 or 2, got {stride!r}")
        if output_c % 2 != 0:
            raise ValueError(f"output_c must be even, got {output_c!r}")
        branch_features = output_c // 2
        if stride == 1 and input_c != branch_features * 2:
            raise ValueError(
                "stride 1 requires input_c == output_c, "
                f"got input_c={input_c!r}, output_c={output_c!r}")
        self.stride = stride

        if self.stride == 2:
            # Downsampling path applied to the full input: depthwise 3x3 + BN
            # (no activation), then a pointwise projection to half the output.
            self.branch1 = nn.Sequential(
                depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.LeakyReLU(negative_slope=0.1, inplace=True))
        else:
            # Empty placeholder; forward() does not call branch1 when stride == 1.
            self.branch1 = nn.Sequential()
        # pointwise -> depthwise 3x3 -> pointwise. With stride 1 this branch
        # only receives half of the input channels (the second chunk).
        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.LeakyReLU(negative_slope=0.1, inplace=True))

    def forward(self, x):
        """Apply the unit to ``x`` (channels on dim 1) and shuffle the result."""
        if self.stride == 1:
            # Channel split: identity on the first half, branch2 on the second.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        # Interleave channels from the two halves so that the next unit's
        # split sees channels from both branches.
        out = channel_shuffle(out, 2)
        return out


class ShuffleNet(nn.Module):
    """ShuffleNetV2-style backbone returning three multi-scale feature maps.

    Layout: a stride-2 3x3 conv stem, a stride-2 3x3 max-pool, then three
    stages. Each stage is one stride-2 ``InvertedResidual`` unit followed by
    one stride-1 unit. For an ``N x 3 x 480 x 640`` input with the default
    widths, the stage outputs are ``N x 48 x 60 x 80``, ``N x 96 x 30 x 40``
    and ``N x 192 x 15 x 20``.

    Args:
        in_channels: channels of the input tensor (default 3).
        stem_channels: output channels of the stem conv (default 24).
        stage_channels: output channels of ``stage1``, ``stage2`` and
            ``stage3`` (default ``(48, 96, 192)``). Each value must be even
            (``InvertedResidual`` requirement).

    The defaults reproduce the previously hard-coded architecture with the
    same submodule names, so existing ``state_dict`` checkpoints still load.

    Raises:
        ValueError: if ``stage_channels`` does not have exactly three entries.
    """

    def __init__(self, in_channels=3, stem_channels=24, stage_channels=(48, 96, 192)):
        super().__init__()
        stage_channels = tuple(stage_channels)
        if len(stage_channels) != 3:
            raise ValueError(
                f"stage_channels must have exactly 3 entries, got {len(stage_channels)}")
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, stem_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_channels),
            nn.LeakyReLU(negative_slope=0.1, inplace=True))
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        c1, c2, c3 = stage_channels
        self.stage1 = self._make_stage(stem_channels, c1)
        self.stage2 = self._make_stage(c1, c2)
        self.stage3 = self._make_stage(c2, c3)

    @staticmethod
    def _make_stage(input_c, output_c):
        """Build one stage: a stride-2 unit followed by a stride-1 unit."""
        return nn.Sequential(
            InvertedResidual(input_c, output_c, 2),
            InvertedResidual(output_c, output_c, 1))

    def forward(self, x):
        """Run the backbone and return the output of every stage.

        Args:
            x: input batch, presumably ``(N, in_channels, H, W)``; the shape
                notes below assume ``N x 3 x 480 x 640``.

        Returns:
            Tuple ``(stage1_out, stage2_out, stage3_out)``.
        """
        x = self.conv1(x)    # default widths: Nx24x240x320
        x = self.maxpool(x)  # Nx24x120x160
        features = []
        # stage1 -> Nx48x60x80, stage2 -> Nx96x30x40, stage3 -> Nx192x15x20
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
            features.append(x)
        return tuple(features)
